import os
import sys
import time
import itertools
import logging

print("aa-1")

import numpy as np
np.random.seed(0)
import pandas as pd

print("aa-2")

import gym
from gym import wrappers

print("aa-22")
import tensorflow as tf

print("aa-3")

from tensorflow import keras
from PIL import Image
import matplotlib.pyplot as plt

print("aa-4")

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG,
        format='%(asctime)s [%(levelname)s] %(message)s')


env_spec_id = 'BreakoutDeterministic-v4'
# env_spec_id = 'PongDeterministic-v4'
# env_spec_id = 'SeaquestDeterministic-v4'
# env_spec_id = 'SpaceInvadersDeterministic-v4'
# env_spec_id = 'BeamRiderDeterministic-v4'
env = gym.make(env_spec_id)
env = wrappers.Monitor(env, "./gym-results", force=True)
print('Observation space = {}'.format(env.observation_space))
print('Action space = {}'.format(env.action_space))
#print('Max episode steps = {}'.format(env._max_episode_steps))
env.seed(0)



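# Experience replay: a fixed-capacity circular buffer of
# (observation, action, reward, next_observation, done) transitions;
# once full, the oldest entries are overwritten.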
class DQNReplayer:
    def __init__(self, capacity):
        self.memory = pd.DataFrame(index=range(capacity),
                columns=['observation', 'action', 'reward',
                'next_observation', 'done'])
        self.i = 0
        self.count = 0
        self.capacity = capacity
    
    def store(self, *args):
        self.memory.loc[self.i] = args
        self.i = (self.i + 1) % self.capacity
        self.count = min(self.count + 1, self.capacity)
        
    def sample(self, size):
        indices = np.random.choice(self.count, size=size) # sample with replacement
        return tuple(np.stack(self.memory.loc[indices, field])
                for field in self.memory.columns)


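# DQN agent with separate evaluation and target networks, frame
# preprocessing, epsilon-greedy exploration, and experience replay.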
class DQNAgent:
    def __init__(self, env, input_shape, learning_rate=0.00025,
            load_path=None, gamma=0.99,
            replay_memory_size=1000000, batch_size=32,
            replay_start_size=0,
            epsilon=1., epsilon_decrease_rate=9e-7, min_epsilon=0.1,
            random_initial_steps=0,
            clip_reward=True, rescale_state=True,
            update_freq=1, target_network_update_freq=1):
        
        self.action_n = env.action_space.n
        self.gamma = gamma
        
        # Experience replay parameters
        self.replay_memory_size = replay_memory_size
        self.replay_start_size = replay_start_size
        self.batch_size = batch_size
        self.replayer = DQNReplayer(replay_memory_size)
        
        # Image input parameters
        self.img_shape = (input_shape[-1], input_shape[-2]) # (width, height) for PIL
        self.img_stack = input_shape[-3] # number of stacked frames
        
        # Exploration parameters
        self.epsilon = epsilon
        self.epsilon_decrease_rate = epsilon_decrease_rate
        self.min_epsilon = min_epsilon
        self.random_initial_steps = random_initial_steps
        
        self.clip_reward = clip_reward
        self.rescale_state = rescale_state
        
        self.update_freq = update_freq
        self.target_network_update_freq = target_network_update_freq
        
        # Evaluation network
        self.evaluate_net = self.build_network(
                input_shape=input_shape, output_size=self.action_n,
                conv_activation=tf.nn.relu,
                fc_hidden_sizes=[512,], fc_activation=tf.nn.relu,
                learning_rate=learning_rate, load_path=load_path)
        self.evaluate_net.summary() # print the network architecture
        # Target network
        self.target_net = self.build_network(
                input_shape=input_shape, output_size=self.action_n,
                conv_activation=tf.nn.relu,
                fc_hidden_sizes=[512,], fc_activation=tf.nn.relu,
                )
        self.update_target_network()
        
        # Initialize counters
        self.step = 0
        self.fit_count = 0


    def build_network(self, input_shape, output_size, conv_activation,
            fc_hidden_sizes, fc_activation, output_activation=None,
            learning_rate=0.001, load_path=None):
        # Network input format: (batch, channels, rows, cols)
        model = keras.models.Sequential()
        # tf expects (batch, rows, cols, channels), so permute from (batch, channels, rows, cols)
        model.add(keras.layers.Permute((2, 3, 1), input_shape=input_shape))
        
        # Convolutional layers
        model.add(keras.layers.Conv2D(32, 8, strides=4,
                activation=conv_activation))
        model.add(keras.layers.Conv2D(64, 4, strides=2,
                activation=conv_activation))
        model.add(keras.layers.Conv2D(64, 3, strides=1,
                activation=conv_activation))
        
        model.add(keras.layers.Flatten())
        
        # Fully connected layers
        for hidden_size in fc_hidden_sizes:
            model.add(keras.layers.Dense(hidden_size,
                    activation=fc_activation))
        model.add(keras.layers.Dense(output_size,
                activation=output_activation))

        if load_path is not None:
            logging.info('Loading network weights from {}.'.format(load_path))
            model.load_weights(load_path)

        try: # tf2
            optimizer = keras.optimizers.RMSprop(learning_rate, 0.95,
                    momentum=0.95, epsilon=0.01)
        except TypeError: # tf1: keras RMSprop has no momentum argument
            optimizer = tf.train.RMSPropOptimizer(learning_rate, 0.95,
                    momentum=0.95, epsilon=0.01)
        model.compile(loss=keras.losses.mse, optimizer=optimizer)
        return model
        
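    # Preprocessing: resize each raw RGB frame, convert it to grayscale, and
    # stack it with the previous img_stack - 1 frames to form the state.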
    def get_next_state(self, state=None, observation=None):
        img = Image.fromarray(observation, 'RGB')
        img = img.resize(self.img_shape).convert('L') # resize, convert to grayscale
        img = np.asarray(img.getdata(), dtype=np.uint8
                ).reshape(img.size[1], img.size[0]) # convert to np.array

        # Stack frames
        if state is None:
            next_state = np.array([img,] * self.img_stack) # initialize the stack
        else:
            next_state = np.append(state[1:], [img,], axis=0) # drop oldest, append newest
        return next_state
    
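    # Epsilon-greedy action selection: fully random during the first
    # random_initial_steps of an episode, epsilon = 0.05 in test mode,
    # and the current self.epsilon otherwise.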
    def decide(self, state, test=False, step=None):
        if step is not None and step < self.random_initial_steps:
            epsilon = 1.
        elif test:
            epsilon = 0.05
        else:
            epsilon = self.epsilon
        if np.random.rand() < epsilon:
            action = np.random.choice(self.action_n)
        else:
            if self.rescale_state:
                state = state / 128. - 1.
            q_values = self.evaluate_net.predict(state[np.newaxis])[0]
            action = np.argmax(q_values)
        return action

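    # One learning step: store the transition, then every update_freq steps
    # fit the evaluation network toward the TD target
    #     y = r + gamma * (1 - done) * max_a' Q_target(s', a')
    # and periodically sync the target network.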
    def learn(self, state, action, reward, next_state, done):
        self.replayer.store(state, action, reward, next_state, done)

        self.step += 1
        
        if self.step % self.update_freq == 0 and \
                self.replayer.count >= self.replay_start_size:
            states, actions, rewards, next_states, dones = \
                    self.replayer.sample(self.batch_size) # sample a replay batch

            if self.rescale_state:
                states = states / 128. - 1.
                next_states = next_states / 128. - 1.
            if self.clip_reward:
                rewards = np.clip(rewards, -1., 1.)
            
            next_qs = self.target_net.predict(next_states)
            next_max_qs = next_qs.max(axis=-1)
            targets = self.evaluate_net.predict(states)
            targets[range(self.batch_size), actions] = rewards + \
                    self.gamma * next_max_qs * (1. - dones)

            h = self.evaluate_net.fit(states, targets, verbose=0)
            self.fit_count += 1
            
            if self.fit_count % 100 == 0:
                logging.info('fit {}, epsilon {}, replay size {}, loss {}' \
                        .format(self.fit_count, self.epsilon,
                        self.replayer.count, h.history['loss'][0]))
            
            if self.fit_count % self.target_network_update_freq == 0:
                self.update_target_network()
        
        # Update epsilon: linear decay
        if self.step >= self.replay_start_size:
            self.epsilon = max(self.epsilon - self.epsilon_decrease_rate,
                               self.min_epsilon)

    def update_target_network(self): # update target network weights
        self.target_net.set_weights(self.evaluate_net.get_weights())
        logging.info('Target network updated')

    def save_network(self, path): # save network weights
        dirname = os.path.dirname(path)
        if not os.path.exists(dirname):
            os.makedirs(dirname)
            logging.info('Created directory {}'.format(dirname))
        self.evaluate_net.save_weights(path)
        logging.info('Network weights saved to {}'.format(path))



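# Run the agent for a number of episodes without learning and log
# per-episode steps and rewards.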
def test(env, agent, episodes=100, render=False, verbose=True):
    steps, episode_rewards = [], []
    for episode in range(episodes):
        episode_reward = 0
        observation = env.reset()
        state = agent.get_next_state(None, observation)
        for step in itertools.count():
            if render:
                env.render()
            action = agent.decide(state, test=True, step=step)
            observation, reward, done, info = env.step(action)
            state = agent.get_next_state(state, observation)
            episode_reward += reward
            if done:
                break
        step += 1
        steps.append(step)
        episode_rewards.append(episode_reward)
        logging.info('[test] episode {}: steps {}, reward {}, total steps {}'
                .format(episode, step, episode_reward, np.sum(steps)))
            
    if verbose:
        logging.info('[test summary] steps: mean = {}, min = {}, max = {}.' \
                .format(np.mean(steps), np.min(steps), np.max(steps)))
        logging.info('[test summary] rewards: mean = {}, min = {}, max = {}' \
                .format(np.mean(episode_rewards), np.min(episode_rewards),
                np.max(episode_rewards)))
    return episode_rewards



render = False
load_path = None
save_path = './output/' + env.unwrapped.spec.id + '-' + \
        time.strftime('%Y%m%d-%H%M%S') + '/model.h5'

"""
Nature 文章使用的参数, 运行极慢, 请勿轻易尝试
"""
input_shape = (4, 110, 84) # 输入网络大小
batch_size = 32
replay_memory_size = 1000000
target_network_update_freq = 10000
gamma = 0.99
update_freq = 4 # 训练网络的间隔
learning_rate = 0.00025 # 优化器学习率
epsilon = 1. # 初始探索率
min_epsilon = 0.1 # 最终探索率
epsilon_decrease = 9e-7 # 探索减小速度
replay_start_size = 50000 # 开始训练前的经验数
random_inital_steps = 30 # 每个回合开始时随机步数
frames = 50000000 # 整个算法的总训练步数
test_freq = 50000 # 验证智能体的步数间隔
test_episodes = 100 # 每次验证智能体的回合数


"""
小规模参数, 运行时间数小时, 有一点点训练效果
"""
batch_size = 32
replay_memory_size = 50000
target_network_update_freq = 4000
replay_start_size = 10000
random_inital_steps = 30
frames = 100000
test_freq = 25000
test_episodes = 50


# """
# 超小规模参数, 数分钟即可运行完, 基本没有训练效果
# """
# batch_size = 6
# replay_memory_size = 5000
# target_network_update_freq = 80
# replay_start_size = 500
# random_inital_steps = 30
# frames = 7500
# test_freq = 2500
# test_episodes = 10
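
# A minimal training-loop sketch (an illustrative addition, disabled by
# default via the hypothetical `train` flag) showing how the hyperparameters
# above would drive DQNAgent.learn(); the run below only evaluates a saved
# model.
train = False
if train:
    agent = DQNAgent(env, input_shape=input_shape,
            learning_rate=learning_rate, gamma=gamma,
            replay_memory_size=replay_memory_size, batch_size=batch_size,
            replay_start_size=replay_start_size,
            epsilon=epsilon, epsilon_decrease_rate=epsilon_decrease,
            min_epsilon=min_epsilon,
            random_initial_steps=random_initial_steps,
            update_freq=update_freq,
            target_network_update_freq=target_network_update_freq)
    frame = 0
    while frame < frames:
        observation = env.reset()
        state = agent.get_next_state(None, observation)
        for step in itertools.count():
            action = agent.decide(state, step=step)
            observation, reward, done, info = env.step(action)
            next_state = agent.get_next_state(state, observation)
            agent.learn(state, action, reward, next_state, done)
            state = next_state
            frame += 1
            if frame % test_freq == 0: # periodically save and evaluate
                agent.save_network(save_path)
                test(env, agent, episodes=test_episodes, verbose=False)
            if done:
                break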


save_path = './output/BreakoutDeterministic-v4-20200415-070411/model.h5'
print(save_path)


test_agent = DQNAgent(env, input_shape=input_shape, load_path=save_path)
test_episode_rewards = test(env, test_agent, episodes=1, render=True)
print('Mean episode reward = {}'.format(np.mean(test_episode_rewards)))

env.close()